import argparse
from mcp.server.fastmcp import FastMCP
from starlette.applications import Starlette
from mcp.server.sse import SseServerTransport
from starlette.requests import Request
from starlette.routing import Mount, Route
from mcp.server import Server
import logging
import uvicorn

# 定义服务器名称
MCP_SERVER_NAME = "modulo-mcp-sse"

# 配置日志
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(MCP_SERVER_NAME)

# 初始化 FastMCP 实例
mcp = FastMCP(MCP_SERVER_NAME)

# 定义取模工具
@mcp.tool()
def modulo(a: float, b: float) -> float:
    """
    计算两个数的取模结果。

    参数：
    - a (float)：被除数（必填）
    - b (float)：除数（必填，不能为零）

    返回：
    - float：a % b 的结果
    """
    if b == 0:
        raise ValueError("除数不能为零")
    return a % b

# 创建 Starlette 应用
def create_starlette_app(mcp_server: Server, *, debug: bool = False) -> Starlette:
    """创建一个支持 SSE 的 Starlette 应用，用于运行 MCP 服务器。"""
    sse = SseServerTransport("/messages/")

    async def handle_sse(request: Request) -> None:
        async with sse.connect_sse(
                request.scope,
                request.receive,
                request._send,
        ) as (read_stream, write_stream):
            await mcp_server.run(
                read_stream,
                write_stream,
                mcp_server.create_initialization_options(),
            )

    return Starlette(
        debug=debug,
        routes=[
            Route("/sse", endpoint=handle_sse),
            Mount("/messages/", app=sse.handle_post_message),
        ],
    )

# 主程序入口
if __name__ == "__main__":
    mcp_server = mcp._mcp_server

    # 解析命令行参数
    parser = argparse.ArgumentParser(description='运行基于 SSE 的 MCP 取模服务器')
    parser.add_argument('--host', default='0.0.0.0', help='绑定的主机地址')
    parser.add_argument('--port', type=int, default=18081, help='监听端口')
    args = parser.parse_args()

    # 创建并运行 Starlette 应用
    starlette_app = create_starlette_app(mcp_server, debug=True)
    uvicorn.run(starlette_app, host=args.host, port=args.port)
